package dsl

import dsl.hdlbits_verilog_language.basics.*
import dsl.hdlbits_verilog_language.procedures.alwaysBlocksComb
import me.yricky.kaguya.base.ExpressionWire
import me.yricky.kaguya.base.NamedWire
import me.yricky.kaguya.base.Reg
import me.yricky.kaguya.compile.ICompiler
import me.yricky.kaguya.compile.verilog.KaguyaVerilogCompiler
import me.yricky.kaguya.dsl.*
import me.yricky.kaguya.dsl.utils.rShr
import me.yricky.kaguya.dsl.utils.time
import java.io.File
import kotlin.math.pow
import kotlin.test.Test

/**
 * Generates Verilog from the Kaguya DSL examples and writes/prints the results.
 *
 * The tests in this class print and/or write generated Verilog; none of them assert on its content.
 */
class ModuleTest {
    /** Output directory (relative to the working directory) for the generated `.v` files. */
    val file = File("generate")

    init {
        // mkdirs() is a no-op when the directory already exists.
        file.mkdirs()
    }

    /** Compiles the HDLBits example modules and writes each one into [file]. */
    @Test
    fun moduleDslTest() {
        val compiler: ICompiler = KaguyaVerilogCompiler()

        // Writes the generated source to `file/fileName`, then echoes it to stdout.
        fun emit(fileName: String, verilog: String) {
            File(file, fileName).writeText(verilog)
            println(verilog)
        }

        emit("01wire.v", compiler.compile(wire()))
        emit("02wire4.v", compiler.compile(wire4()))
        emit("03notgate.v", compiler.compile(notGate()))
        emit("04andgate.v", compiler.compile(andGate()))
        emit("05norgate.v", compiler.compile(norGate()))
        emit("06xnorgate.v", compiler.compile(xnorGate()))
        emit("07wire_decl.v", compiler.compile(wireDecline()))
        emit("08_7458.v", compiler.compile(chip7458()))
        emit("09vector.v", compiler.compile(vector()))
        emit("10vector1.v", compiler.compile(vector1()))
        emit("11vector2.v", compiler.compile(vector2()))
        emit("12vectorgates.v", compiler.compile(vectorGates()))
        emit("13gates4.v", compiler.compile(gates4()))
        emit("14vector3.v", compiler.compile(vector3()))
        emit("15vectorr.v", compiler.compile(vectorr()))
        emit("16vector4.v", compiler.compile(vector4()))
        emit("17vector5.v", compiler.compile(vector5()))
        emit("18module.v", compiler.compile(module()))
        emit("alwaysBlock.v", compiler.compile(alwaysBlocksComb()))
    }

    /**
     * Builds a counter register that increments on every clock and wraps back to 0.
     *
     * The counter runs 0, 1, ..., [maxCount], 0, ..., i.e. it has a period of `maxCount + 1` clocks.
     *
     * @param name signal name of the counter register
     * @param clk clock signal
     * @param rstN active-low reset signal
     * @param maxCount largest value the counter reaches before wrapping; must be >= 0
     * @return the counter register (width derived from [maxCount]) and a signal that is
     *   true while the counter equals [maxCount]
     * @throws IllegalArgumentException if [maxCount] is negative
     */
    fun ModuleScope.counterReg(
        name: String,
        clk: NamedWire,
        rstN: NamedWire,
        maxCount: Long,
    ): Pair<Reg, ExpressionWire> {
        require(maxCount >= 0) { "maxCount must be non-negative: $maxCount" }
        // Bits needed to hold maxCount itself; at least 1 so maxCount = 0 still gets a register.
        val width = (Long.SIZE_BITS - maxCount.countLeadingZeroBits()).coerceAtLeast(1)

        val counter = reg(name, width)
        time(clk, rstN) { reset ->
            counter.from(
                reset to d(0, width),
                // Strictly below maxCount: the last increment lands exactly on maxCount, and the
                // following clock falls through to `default`, wrapping the counter to 0.
                (counter lessThan d(maxCount, width)) to (counter + d(1, width)),
                default = d(0, width),
            )
        }
        return Pair(counter, counter.eq(d(maxCount, width)))
    }

    /**
     * Builds a 6-bit LED register that is updated with `out.rShr(1)` (presumably a rotate-right;
     * confirm) each time the counter trigger fires, and reset to 1.
     *
     * The generated Verilog is written into the author's Gowin project when that directory exists,
     * and into [file] otherwise, so the test does not fail on other machines.
     */
    @Test
    fun led() {
        module("led") {
            val clk = inputWire("sys_clk")
            val rstN = inputWire("sys_rst_n")
            val out = outputReg("led", 6)
            // Trigger fires once every 1349_9999 + 1 clocks.
            // NOTE(review): presumably 0.5 s at a 27 MHz board clock; confirm against the board.
            val (_, counterTrig) = counterReg("count", clk, rstN, 1349_9999)
            time(clk, rstN) { reset ->
                out.from(
                    reset to h(1, 6),
                    counterTrig to out.rShr(1),
                )
            }
        }.also {
            val gowinSrc = File("/home/yricky/Documents/gowin/led_test/src")
            val target = if (gowinSrc.isDirectory) File(gowinSrc, "test.v") else File(file, "led.v")
            target.writeText(KaguyaVerilogCompiler().compile(it))
        }
    }

    /** Returns [base] raised to [pow], computed in floating point and truncated to [Int]. */
    fun pow(base: Int, pow: Int): Int {
        return base.toDouble().pow(pow).toInt()
    }

    /**
     * HDLBits 256-to-1 multiplexer built from select masks rather than an index expression.
     *
     * One 256-bit mask is built per select bit; ANDing the eight masks is intended to leave a single
     * set bit at position `sel`, which is then ANDed with `in` and OR-reduced into `out`.
     */
    @Test
    fun test256to1() {
        module("top_module") {
            val sig = inputWire("in", 256)
            val sel = inputWire("sel", 8)
            (0 until 8).map { bit ->
                // Block of 2^(bit+1) bits: one half sel[bit], the other half ~sel[bit]
                // (assumes `joint` concatenates MSB-first like Verilog `{a, b}`; confirm),
                // tiled 2^(7-bit) times to cover all 256 positions.
                joint(
                    sel[bit].repeat(pow(2, bit)),
                    sel[bit].inv().repeat(pow(2, bit))
                ).repeat(pow(2, 7 - bit))
            }.reduce { s1, s2 -> s1 and s2 }.let {
                it and sig
            }.or().exportAsOutput("out")
        }.also {
            println(KaguyaVerilogCompiler().compile(it))
        }
    }

    /** Prints the compiler's `removeBracket` result for a sample string; makes no assertion. */
    @Test
    fun removeBracket() {
        println(with(KaguyaVerilogCompiler()) {
            "((a)a(a))".removeBracket()
        })
    }
}